package com.mofum.scope.processor.impl;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.dialect.db2.visitor.DB2SchemaStatVisitor;
import com.alibaba.druid.sql.dialect.h2.visitor.H2SchemaStatVisitor;
import com.alibaba.druid.sql.dialect.hive.visitor.HiveSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.odps.visitor.OdpsSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.oracle.visitor.OracleSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.phoenix.visitor.PhoenixSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.postgresql.visitor.PGSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.sqlserver.visitor.SQLServerSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.util.JdbcConstants;
import com.alibaba.druid.util.StringUtils;
import com.mofum.scope.common.compatible.IScopeCompatibleListener;
import com.mofum.scope.common.model.Permission;
import com.mofum.scope.common.model.Scope;
import com.mofum.scope.config.ScopeConfig;
import com.mofum.scope.processor.IRestructureProcessor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

/**
 * Druid-based SQL restructure processor.
 *
 * <p>Parses the SQL of a MyBatis {@link BoundSql} with Druid and appends data-scope restrictions
 * ({@code column IN (...)} predicates) to the WHERE clause of the outer SELECT query block. When an
 * {@link IScopeCompatibleListener} is configured, the listener produces the final SQL instead.
 *
 * @author developer@omuao.com
 * @since 2019-04-26
 **/
public class DruidRestructureProcessor implements IRestructureProcessor {

    /**
     * XML comments ({@code <!-- ... -->}, possibly spanning lines) left in mapper SQL; they are removed
     * before parsing. The body is matched with the reluctant {@code .*?} (not {@code .+?}) so that an
     * empty comment {@code <!---->} is matched on its own instead of the match running on into the
     * next comment and deleting the SQL in between.
     */
    private static final Pattern XML_COMMENT = Pattern.compile("(?s)<!--.*?-->");

    /**
     * Druid database type; {@link #getDbType()} falls back to
     * {@link ScopeConfig#DEFAULT_DRUID_DB_TYPE} when it is unset.
     */
    private String dbType;

    /**
     * Compatibility listener. When set, it produces the final SQL and the built-in scope wrapping is
     * skipped.
     */
    private IScopeCompatibleListener listener;

    /**
     * Rewrites {@code permissionSql} so that it honors the scopes carried by {@code permission}.
     *
     * @param ms            the mapped statement being executed (supplies the MyBatis configuration)
     * @param permissionSql the original bound SQL
     * @param permission    the permission whose scope collections are applied; may be {@code null}
     * @return a new {@link BoundSql} carrying the rewritten SQL, or {@code permissionSql} itself when
     * the SQL parses to no statement
     * @throws RuntimeException wrapping any failure to parse or rewrite the SQL (including non-SELECT
     *                          statements and SELECTs without a single query block, e.g. UNION)
     */
    @Override
    public BoundSql process(MappedStatement ms, BoundSql permissionSql, Permission permission) {
        return restructureSql(ms, permission, permissionSql);
    }

    /**
     * Parses the bound SQL, then either delegates to the compatibility listener or wraps the outer
     * query block with the permission's scope restrictions.
     */
    private BoundSql restructureSql(MappedStatement ms, Permission permission, BoundSql permissionSql) {
        String paramsSql = stripXmlComments(permissionSql.getSql());

        try {
            String currentDbType = getDbType();

            SQLStatement sqlStatement = parseFirstStatement(paramsSql, currentDbType);
            if (sqlStatement == null) {
                return permissionSql;
            }

            SchemaStatVisitor schemaStatVisitor = createSchemaStatVisitor(currentDbType);
            SQLSelectStatement sqlSelectStatement = asSelectStatement(sqlStatement);
            sqlSelectStatement.accept(schemaStatVisitor);
            SQLSelect sqlSelect = sqlSelectStatement.getSelect();

            if (listener != null) {
                // The listener owns the rewrite; its output is re-parsed and re-printed by Druid.
                String compatibleSql = listener.compatible(sqlSelect.toString(), currentDbType, permission, ms);
                SQLStatement compatibleStatement = parseFirstStatement(compatibleSql, currentDbType);
                if (compatibleStatement == null) {
                    return permissionSql;
                }
                SQLSelectStatement compatibleSelect = asSelectStatement(compatibleStatement);
                compatibleSelect.accept(schemaStatVisitor);
                return copyBoundSql(ms, permissionSql, compatibleSelect.getSelect().toString());
            }

            // NOTE(review): the outer query block is wrapped. An earlier (commented-out) draft unwrapped
            // the sub-query for statements whose id ends with "_COUNT"; confirm that is not needed.
            Object query = sqlSelect.getQuery();
            if (!(query instanceof SQLSelectQueryBlock)) {
                // e.g. UNION: there is no single WHERE clause to extend. Fail rather than return
                // unrestricted SQL.
                throw new IllegalArgumentException("Scope restructuring requires a plain SELECT query block, got: "
                        + (query == null ? "null" : query.getClass().getSimpleName()));
            }

            // Wrap the permission content into the WHERE clause.
            wrapperPermission((SQLSelectQueryBlock) query, permission, currentDbType);

            return copyBoundSql(ms, permissionSql, sqlSelect.toString());

        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Removes XML comments from {@code sql}; {@code null} is passed through unchanged.
     */
    private static String stripXmlComments(String sql) {
        return sql == null ? null : XML_COMMENT.matcher(sql).replaceAll("");
    }

    /**
     * Parses {@code sql} with Druid and returns its first statement, or {@code null} when the parser
     * yields no statement at all.
     */
    private static SQLStatement parseFirstStatement(String sql, String dbType) {
        List<SQLStatement> statements = SQLUtils.parseStatements(sql, dbType);
        if (statements == null || statements.isEmpty()) {
            return null;
        }
        return statements.get(0);
    }

    /**
     * Casts {@code statement} to a SELECT statement, failing with a descriptive message otherwise.
     */
    private static SQLSelectStatement asSelectStatement(SQLStatement statement) {
        if (!(statement instanceof SQLSelectStatement)) {
            throw new IllegalArgumentException("Only SELECT statements can be restructured, got: "
                    + statement.getClass().getSimpleName());
        }
        return (SQLSelectStatement) statement;
    }

    /**
     * Creates the Druid schema-stat visitor that matches {@code dbType}, falling back to the generic
     * visitor for unknown types.
     */
    private static SchemaStatVisitor createSchemaStatVisitor(String dbType) {
        switch (dbType) {
            case JdbcConstants.MYSQL:
            case JdbcConstants.MARIADB:
                return new MySqlSchemaStatVisitor();
            case JdbcConstants.DB2:
                return new DB2SchemaStatVisitor();
            case JdbcConstants.ORACLE:
                return new OracleSchemaStatVisitor();
            case JdbcConstants.SQL_SERVER:
                return new SQLServerSchemaStatVisitor();
            case JdbcConstants.POSTGRESQL:
                return new PGSchemaStatVisitor();
            case JdbcConstants.ODPS:
                return new OdpsSchemaStatVisitor();
            case JdbcConstants.H2:
                return new H2SchemaStatVisitor();
            case JdbcConstants.HIVE:
                return new HiveSchemaStatVisitor();
            case JdbcConstants.PHOENIX:
                return new PhoenixSchemaStatVisitor();
            default:
                return new SchemaStatVisitor();
        }
    }

    /**
     * Builds a {@link BoundSql} for {@code sql} that reuses the parameter mappings and parameter object
     * of {@code source}. The {@code BoundSql} constructor does not copy additional parameters, such as
     * the values MyBatis binds for {@code <foreach>}/{@code <bind>}. Any of them referenced by the
     * parameter mappings are therefore copied over explicitly, so the mappings still resolve.
     */
    private static BoundSql copyBoundSql(MappedStatement ms, BoundSql source, String sql) {
        List<ParameterMapping> mappings = source.getParameterMappings();
        BoundSql target = new BoundSql(ms.getConfiguration(), sql, mappings, source.getParameterObject());
        if (mappings != null) {
            for (ParameterMapping mapping : mappings) {
                String name = baseParameterName(mapping.getProperty());
                if (name != null && source.hasAdditionalParameter(name)) {
                    target.setAdditionalParameter(name, source.getAdditionalParameter(name));
                }
            }
        }
        return target;
    }

    /**
     * Returns the root name of a property path ({@code "a.b"} and {@code "a[0]"} both yield
     * {@code "a"}), or {@code null} when there is none.
     */
    private static String baseParameterName(String property) {
        if (property == null) {
            return null;
        }
        int end = property.length();
        int dot = property.indexOf('.');
        if (dot >= 0) {
            end = dot;
        }
        int bracket = property.indexOf('[');
        if (bracket >= 0 && bracket < end) {
            end = bracket;
        }
        return end == 0 ? null : property.substring(0, end);
    }

    /**
     * Wraps the permission content into the WHERE clause of {@code resultPlainSelect}.
     *
     * <p>Scopes are grouped by table name; scopes without a table name are grouped under {@code ""}
     * and produce unqualified column predicates. Nothing is appended when the permission has no scopes.
     *
     * @param resultPlainSelect the query block whose WHERE clause is extended
     * @param permission        the permission to apply; may be {@code null}
     * @param dbType            the Druid database type used to parse and print expressions
     */
    public static void wrapperPermission(SQLSelectQueryBlock resultPlainSelect, Permission permission, String dbType) {
        // Nothing to apply when there is no permission or it carries no scopes.
        if (permission == null) {
            return;
        }
        List<Scope> scopes = permission.getScopeCollections();
        if (scopes == null || scopes.isEmpty()) {
            return;
        }

        // Table -> scopes. Null table names share the "" key so they are not split across lists.
        Map<String, List<Scope>> tableMap = new LinkedHashMap<>();
        for (Scope scope : scopes) {
            if (scope == null) {
                continue;
            }
            String table = scope.getTableName() == null ? "" : scope.getTableName();
            List<Scope> tableScopes = tableMap.get(table);
            if (tableScopes == null) {
                tableScopes = new ArrayList<>();
                tableMap.put(table, tableScopes);
            }
            tableScopes.add(scope);
        }

        for (Map.Entry<String, List<Scope>> entry : tableMap.entrySet()) {
            wrapperScope(resultPlainSelect, entry.getKey(), entry.getValue(), dbType);
        }
    }

    /**
     * Wraps the scopes of one table: groups their ids by column and appends one {@code IN} predicate
     * per column. Scopes with a {@code null} or blank column are ignored.
     *
     * @param resultPlainSelect the query block whose WHERE clause is extended
     * @param table             table name (or alias) used to qualify the columns; empty for none
     * @param scopes            the scopes belonging to {@code table}; {@code null} means none
     * @param dbType            the Druid database type used to parse and print expressions
     */
    public static void wrapperScope(SQLSelectQueryBlock resultPlainSelect, String table, List<Scope> scopes, String dbType) {
        if (scopes == null) {
            return;
        }

        // Column -> scope ids.
        Map<String, List<String>> scopesMap = new LinkedHashMap<>();
        for (Scope scope : scopes) {
            if (scope == null) {
                continue;
            }
            String column = scope.getColumn();
            if (column == null || column.trim().isEmpty()) {
                continue;
            }
            List<String> ids = scopesMap.get(column);
            if (ids == null) {
                ids = new ArrayList<>();
                scopesMap.put(column, ids);
            }
            ids.add(scope.getId());
        }

        for (Map.Entry<String, List<String>> entry : scopesMap.entrySet()) {
            wrapperScopes(resultPlainSelect, table, entry.getKey(), entry.getValue(), dbType);
        }
    }

    /**
     * Appends {@code [table.]column IN (ids...)} to the WHERE clause of {@code resultPlainSelect},
     * AND-ed with any existing condition.
     *
     * @param resultPlainSelect the query block whose WHERE clause is extended
     * @param table             qualifier for {@code column}; {@code null} or empty for none
     * @param column            the restricted column (inserted verbatim, so it must be trusted)
     * @param strings           the allowed values, emitted as escaped string literals
     * @param dbType            the Druid database type used to parse and print expressions
     */
    public static void wrapperScopes(SQLSelectQueryBlock resultPlainSelect, String table, String column, List<String> strings, String dbType) {
        // Assemble the condition.
        String left = (table == null || StringUtils.isEmpty(table)) ? column : String.format("%s.%s", table, column);
        SQLExpr sqlExpr = inExpression(left, strings, dbType);

        SQLExpr whereExpr = resultPlainSelect.getWhere();
        resultPlainSelect.setWhere(andExpression(sqlExpr, whereExpr, dbType));
    }

    /**
     * Builds the expression {@code left IN ('v1','v2',...)}.
     *
     * <p>Values are escaped as SQL string literals. An empty value list yields {@code 1 = 0}, because
     * {@code x IN ()} is not valid SQL and an empty allowed set matches no row.
     */
    private static SQLExpr inExpression(String left, List<String> rightItems, String dbType) {
        if (rightItems == null || rightItems.isEmpty()) {
            return SQLUtils.toSQLExpr("1 = 0", dbType);
        }

        StringBuilder sql = new StringBuilder(left == null ? "" : left).append(" in (");
        for (int i = 0; i < rightItems.size(); i++) {
            if (i > 0) {
                sql.append(',');
            }
            sql.append(quoteLiteral(rightItems.get(i), dbType));
        }
        sql.append(')');
        return SQLUtils.toSQLExpr(sql.toString(), dbType);
    }

    /**
     * Quotes {@code value} as a SQL string literal. Single quotes are doubled. For MySQL/MariaDB, whose
     * lexer treats backslash as an escape character, backslashes are doubled as well, so that a value
     * cannot terminate the literal early. {@code null} is rendered as {@code 'null'}.
     */
    private static String quoteLiteral(String value, String dbType) {
        String escaped = String.valueOf(value);
        if (JdbcConstants.MYSQL.equals(dbType) || JdbcConstants.MARIADB.equals(dbType)) {
            escaped = escaped.replace("\\", "\\\\");
        }
        escaped = escaped.replace("'", "''");
        return "'" + escaped + "'";
    }

    /**
     * Builds {@code (right) and left}, or just the non-null side when one side is {@code null};
     * returns {@code null} when both are {@code null}.
     */
    private static SQLExpr andExpression(SQLExpr left, SQLExpr right, String dbType) {
        StringBuilder stringBuilder = new StringBuilder();

        if (left == null && right == null) {
            return null;
        } else if (right == null) {
            stringBuilder.append(sqlExprString(left, dbType));
        } else if (left == null) {
            stringBuilder.append(sqlExprString(right, dbType));
        } else {
            // Parenthesize the existing condition so an OR inside it cannot bypass the new restriction.
            stringBuilder.append("(").append(sqlExprString(right, dbType)).append(") and ")
                    .append(sqlExprString(left, dbType));
        }
        return SQLUtils.toSQLExpr(stringBuilder.toString(), dbType);
    }

    /**
     * Prints a SQL expression as a string for the given database type.
     */
    private static String sqlExprString(SQLExpr sqlExpr, String dbType) {
        return SQLUtils.toSQLString(sqlExpr, dbType);
    }

    /**
     * Returns the configured Druid database type. If none is set, the field is initialised to
     * {@link ScopeConfig#DEFAULT_DRUID_DB_TYPE} and that value is returned.
     */
    public String getDbType() {
        if (dbType == null) {
            dbType = ScopeConfig.DEFAULT_DRUID_DB_TYPE;
        }
        return dbType;
    }

    /**
     * Sets the Druid database type.
     */
    public void setDbType(String dbType) {
        this.dbType = dbType;
    }

    /**
     * Manual demo: wraps a sample MySQL SELECT with two scopes and prints the resulting SQL.
     */
    public static void main(String[] args) {
        // Statement
        SQLSelectStatement originSelect = (SQLSelectStatement) SQLUtils.parseStatements("select * from tb Order By col desc Limit 1", JdbcConstants.MYSQL).get(0);

        SQLSelect originPlainSelect = originSelect.getSelect();

        SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) originPlainSelect.getQuery();

        Permission permission = new Permission();

        List<Scope> scopes = new ArrayList<>();

        Scope scope = new Scope();
        scope.setTableName("");
        scope.setColumn("wmsId");
        scope.setId("1235467");

        Scope scope2 = new Scope();
        scope2.setTableName("t2");
        scope2.setColumn("shopId");
        scope2.setId("1235467");

        scopes.add(scope);
        scopes.add(scope2);

        permission.setScopeCollections(scopes);

        wrapperPermission(sqlSelectQueryBlock, permission, JdbcConstants.MYSQL);

        System.out.println(originPlainSelect.toString());
    }

    /**
     * Sets the compatibility listener (delegates to {@link #setListener}).
     */
    @Override
    public void setCompatibleListener(IScopeCompatibleListener listener) {
        this.setListener(listener);
    }

    /**
     * Returns the compatibility listener, or {@code null} if none is set.
     */
    public IScopeCompatibleListener getListener() {
        return listener;
    }

    /**
     * Sets the compatibility listener; {@code null} restores the built-in scope wrapping.
     */
    public void setListener(IScopeCompatibleListener listener) {
        this.listener = listener;
    }
}
